!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Calculate the saop potential
! **************************************************************************************************
MODULE xc_pot_saop
   USE atomic_kind_types,               ONLY: atomic_kind_type,&
                                              get_atomic_kind
   USE basis_set_types,                 ONLY: gto_basis_set_type
   USE cp_array_utils,                  ONLY: cp_1d_r_p_type
   USE cp_control_types,                ONLY: dft_control_type
   USE cp_dbcsr_api,                    ONLY: dbcsr_copy,&
                                              dbcsr_deallocate_matrix,&
                                              dbcsr_p_type,&
                                              dbcsr_set
   USE cp_dbcsr_operations,             ONLY: cp_dbcsr_plus_fm_fm_t,&
                                              dbcsr_allocate_matrix_set,&
                                              dbcsr_deallocate_matrix_set
   USE cp_fm_types,                     ONLY: cp_fm_create,&
                                              cp_fm_get_info,&
                                              cp_fm_get_submatrix,&
                                              cp_fm_p_type,&
                                              cp_fm_release,&
                                              cp_fm_set_all,&
                                              cp_fm_set_submatrix,&
                                              cp_fm_type
   USE input_constants,                 ONLY: do_method_gapw,&
                                              oe_gllb,&
                                              oe_lb,&
                                              oe_saop,&
                                              xc_funct_no_shortcut
   USE input_section_types,             ONLY: &
        section_vals_create, section_vals_duplicate, section_vals_get_subs_vals, &
        section_vals_release, section_vals_retain, section_vals_set_subs_vals, section_vals_type, &
        section_vals_val_get, section_vals_val_set
   USE kinds,                           ONLY: dp
   USE mathconstants,                   ONLY: pi
   USE message_passing,                 ONLY: mp_para_env_type
   USE pw_env_types,                    ONLY: pw_env_get,&
                                              pw_env_type
   USE pw_methods,                      ONLY: pw_axpy,&
                                              pw_copy,&
                                              pw_scale,&
                                              pw_zero
   USE pw_pool_types,                   ONLY: pw_pool_type
   USE pw_types,                        ONLY: pw_c1d_gs_type,&
                                              pw_r3d_rs_type
   USE qs_collocate_density,            ONLY: calculate_rho_elec
   USE qs_environment_types,            ONLY: get_qs_env,&
                                              qs_environment_type
   USE qs_gapw_densities,               ONLY: prepare_gapw_den
   USE qs_grid_atom,                    ONLY: grid_atom_type
   USE qs_harmonics_atom,               ONLY: harmonics_atom_type
   USE qs_integrate_potential,          ONLY: integrate_v_rspace
   USE qs_kind_types,                   ONLY: get_qs_kind,&
                                              qs_kind_type
   USE qs_ks_atom,                      ONLY: update_ks_atom
   USE qs_ks_types,                     ONLY: qs_ks_env_type
   USE qs_local_rho_types,              ONLY: local_rho_set_create,&
                                              local_rho_set_release,&
                                              local_rho_type
   USE qs_mo_types,                     ONLY: get_mo_set,&
                                              mo_set_type
   USE qs_neighbor_list_types,          ONLY: neighbor_list_set_p_type
   USE qs_oce_types,                    ONLY: oce_matrix_type
   USE qs_rho_atom_methods,             ONLY: allocate_rho_atom_internals,&
                                              calculate_rho_atom_coeff
   USE qs_rho_atom_types,               ONLY: get_rho_atom,&
                                              rho_atom_coeff,&
                                              rho_atom_type
   USE qs_rho_types,                    ONLY: qs_rho_get,&
                                              qs_rho_type
   USE qs_vxc_atom,                     ONLY: calc_rho_angular,&
                                              gaVxcgb_noGC
   USE util,                            ONLY: get_limit
   USE virial_types,                    ONLY: virial_type
   USE xc,                              ONLY: xc_vxc_pw_create1
   USE xc_atom,                         ONLY: fill_rho_set,&
                                              vxc_of_r_new,&
                                              xc_rho_set_atom_update
   USE xc_derivative_set_types,         ONLY: xc_derivative_set_type,&
                                              xc_dset_create,&
                                              xc_dset_get_derivative,&
                                              xc_dset_release,&
                                              xc_dset_zero_all
   USE xc_derivative_types,             ONLY: xc_derivative_get,&
                                              xc_derivative_type
   USE xc_derivatives,                  ONLY: xc_functionals_eval
   USE xc_rho_cflags_types,             ONLY: xc_rho_cflags_setall,&
                                              xc_rho_cflags_type
   USE xc_rho_set_types,                ONLY: xc_rho_set_create,&
                                              xc_rho_set_release,&
                                              xc_rho_set_type,&
                                              xc_rho_set_update
   USE xc_xbecke88,                     ONLY: xb88_lda_info,&
                                              xb88_lsd_info
#include "./base/base_uses.f90"

   IMPLICIT NONE

   ! Everything in this module is private unless listed in a PUBLIC statement below
   PRIVATE

   ! add_saop_pot is the only public entity of this module
   PUBLIC :: add_saop_pot

   ! Module name as a character constant
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'xc_pot_saop'

   ! should be eliminated
   ! Hard-coded model parameters at module scope.
   !   alpha : factor applied to the XALPHA-only potential when add_saop_pot adds it to vxc_LB
   !   K_rho : prefactor of the SQRT(e_homo - e_orb) orbital weights in add_saop_pot
   ! NOTE(review): beta, kappa, mu, beta_ec and gamma_saop are not referenced by add_saop_pot;
   !               presumably they are consumed by helper routines further down in this
   !               module -- confirm before changing or removing them.
   REAL(KIND=dp), PARAMETER :: alpha = 1.19_dp, beta = 0.01_dp, K_rho = 0.42_dp
   REAL(KIND=dp), PARAMETER :: kappa = 0.804_dp, mu = 0.21951_dp, &
                               beta_ec = 0.066725_dp, gamma_saop = 0.031091_dp

CONTAINS

! **************************************************************************************************
!> \brief Builds an orbital-energy dependent model xc potential on the real-space grid and
!>        integrates it into the Kohn-Sham matrix.
!>        Up to three potentials are constructed: the LB potential (vxc_LB), the GLLB
!>        potential (vxc_GLLB) and their orbital-weighted combination (vxc_SAOP).
!>        oe_corr selects which of them is finally integrated.
!> \param ks_matrix Kohn-Sham matrices, one per spin; handed as hmat to integrate_v_rspace
!> \param qs_env QS environment from which density, MOs, grids, input and virial are read
!> \param oe_corr potential selector: oe_lb -> LB, oe_gllb -> GLLB, any other value -> SAOP
!> \note The routine aborts (CPASSERT) if virial%pv_calculate is set without virial%pv_numer.
!> \note The number of spins is derived from the DFT%LSD keyword (2 if set, 1 otherwise).
!> \note The per-orbital collocations are only carried out when the selected potential
!>       actually depends on them; the integrated potential is the same as if they were
!>       always computed.
!> \note For GAPW runs gapw_add_atomic_saop_pot is called at the end.
! **************************************************************************************************
   SUBROUTINE add_saop_pot(ks_matrix, qs_env, oe_corr)

      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: ks_matrix
      TYPE(qs_environment_type), POINTER                 :: qs_env
      INTEGER, INTENT(IN)                                :: oe_corr

      INTEGER                                            :: dft_method_id, homo, i, ispin, j, k, &
                                                            nspins, orb, xc_deriv_method_id, &
                                                            xc_rho_smooth_id
      INTEGER, DIMENSION(2)                              :: ncol, nrow
      INTEGER, DIMENSION(2, 3)                           :: bo
      LOGICAL                                            :: compute_virial, gapw, lsd, need_gllb, &
                                                            need_saop
      REAL(KIND=dp)                                      :: density_cut, efac, gradient_cut, &
                                                            tau_cut, we_GLLB, we_LB, xc_energy
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: coeff_col
      REAL(KIND=dp), DIMENSION(3, 3)                     :: virial_xc_tmp
      REAL(KIND=dp), DIMENSION(:), POINTER               :: mo_eigenvalues
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: e_uniform
      TYPE(cp_fm_type), ALLOCATABLE, DIMENSION(:)        :: single_mo_coeff
      TYPE(cp_fm_type), POINTER                          :: mo_coeff
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: orbital_density_matrix, rho_struct_ao
      TYPE(mo_set_type), DIMENSION(:), POINTER           :: molecular_orbitals
      TYPE(pw_c1d_gs_type)                               :: orbital_g
      TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho_g
      TYPE(pw_env_type), POINTER                         :: pw_env
      TYPE(pw_pool_type), POINTER                        :: auxbas_pw_pool
      TYPE(pw_r3d_rs_type)                               :: orbital
      TYPE(pw_r3d_rs_type), ALLOCATABLE, DIMENSION(:)    :: vxc_GLLB, vxc_SAOP
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r, rho_struct_r, tau, vxc_LB, &
                                                            vxc_tau, vxc_tmp
      TYPE(qs_ks_env_type), POINTER                      :: ks_env
      TYPE(qs_rho_type), POINTER                         :: rho_struct
      TYPE(section_vals_type), POINTER                   :: input, xc_fun_section_orig, &
                                                            xc_fun_section_tmp, xc_section_orig, &
                                                            xc_section_tmp
      TYPE(virial_type), POINTER                         :: virial
      TYPE(xc_derivative_set_type)                       :: deriv_set
      TYPE(xc_derivative_type), POINTER                  :: deriv
      TYPE(xc_rho_cflags_type)                           :: needs
      TYPE(xc_rho_set_type)                              :: rho_set

      NULLIFY (ks_env, pw_env, auxbas_pw_pool, input, virial)
      NULLIFY (rho_g, rho_r, tau, rho_struct, e_uniform)
      NULLIFY (vxc_LB, vxc_tmp, vxc_tau)
      NULLIFY (mo_eigenvalues, mo_coeff, molecular_orbitals, deriv, rho_struct_r, rho_struct_ao)
      NULLIFY (orbital_density_matrix, xc_section_orig, xc_section_tmp)
      NULLIFY (xc_fun_section_orig, xc_fun_section_tmp)

      CALL get_qs_env(qs_env, &
                      ks_env=ks_env, &
                      rho=rho_struct, &
                      pw_env=pw_env, &
                      input=input, &
                      virial=virial, &
                      mos=molecular_orbitals)

      ! This routine cannot provide a virial contribution: fail before anything is allocated
      compute_virial = virial%pv_calculate .AND. (.NOT. virial%pv_numer)
      CPASSERT(.NOT. compute_virial)

      CALL section_vals_val_get(input, "DFT%QS%METHOD", i_val=dft_method_id)
      gapw = (dft_method_id == do_method_gapw)

      ! Which intermediate potentials does the requested correction depend on?
      !   oe_lb   : only vxc_LB
      !   oe_gllb : only vxc_GLLB
      !   else    : vxc_SAOP, assembled from vxc_LB and vxc_GLLB
      need_gllb = (oe_corr /= oe_lb)
      need_saop = need_gllb .AND. (oe_corr /= oe_gllb)

      ! Work on a private copy of the XC section, its functional subsection is replaced below
      xc_section_orig => section_vals_get_subs_vals(input, "DFT%XC")
      CALL section_vals_retain(xc_section_orig)
      CALL section_vals_duplicate(xc_section_orig, xc_section_tmp)

      CALL section_vals_val_get(xc_section_orig, "DENSITY_CUTOFF", &
                                r_val=density_cut)
      CALL section_vals_val_get(xc_section_orig, "GRADIENT_CUTOFF", &
                                r_val=gradient_cut)
      CALL section_vals_val_get(xc_section_orig, "TAU_CUTOFF", &
                                r_val=tau_cut)

      CALL pw_env_get(pw_env, auxbas_pw_pool=auxbas_pw_pool)

      CALL section_vals_val_get(input, "DFT%LSD", l_val=lsd)
      IF (lsd) THEN
         nspins = 2
      ELSE
         nspins = 1
      END IF

      ALLOCATE (single_mo_coeff(nspins))
      CALL dbcsr_allocate_matrix_set(orbital_density_matrix, nspins)
      CALL qs_rho_get(rho_struct, rho_r=rho_struct_r, rho_ao=rho_struct_ao)
      rho_r => rho_struct_r
      ! The AO density matrix is copied only to obtain a scratch matrix for the
      ! single-orbital density matrices; its values are reset before every use
      DO ispin = 1, nspins
         ALLOCATE (orbital_density_matrix(ispin)%matrix)
         CALL dbcsr_copy(orbital_density_matrix(ispin)%matrix, &
                         rho_struct_ao(ispin)%matrix, "orbital density")
      END DO
      bo = rho_r(1)%pw_grid%bounds_local

      !---------------------------!
      ! create the density needed !
      !---------------------------!
      CALL xc_rho_set_create(rho_set, bo, &
                             density_cut, &
                             gradient_cut, &
                             tau_cut)
      CALL xc_rho_cflags_setall(needs, .FALSE.)
      IF (lsd) THEN
         CALL xb88_lsd_info(needs=needs)
         needs%norm_drho = .TRUE.
      ELSE
         CALL xb88_lda_info(needs=needs)
      END IF
      CALL section_vals_val_get(xc_section_orig, "XC_GRID%XC_DERIV", &
                                i_val=xc_deriv_method_id)
      CALL section_vals_val_get(xc_section_orig, "XC_GRID%XC_SMOOTH_RHO", &
                                i_val=xc_rho_smooth_id)
      ! rho_g and tau are passed unassociated
      CALL xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, &
                             xc_deriv_method_id, &
                             xc_rho_smooth_id, &
                             auxbas_pw_pool)

      !----------------------------------------!
      ! Construct the LB94 potential in vxc_LB !
      !----------------------------------------!
      ! Step 1: potential of an XALPHA-only functional -> vxc_tmp
      xc_fun_section_orig => section_vals_get_subs_vals(xc_section_orig, &
                                                        "XC_FUNCTIONAL")
      CALL section_vals_create(xc_fun_section_tmp, xc_fun_section_orig%section)
      CALL section_vals_val_set(xc_fun_section_tmp, "_SECTION_PARAMETERS_", &
                                i_val=xc_funct_no_shortcut)
      CALL section_vals_val_set(xc_fun_section_tmp, "XALPHA%_SECTION_PARAMETERS_", &
                                l_val=.TRUE.)
      CALL section_vals_set_subs_vals(xc_section_tmp, "XC_FUNCTIONAL", &
                                      xc_fun_section_tmp)

      CALL xc_vxc_pw_create1(vxc_tmp, vxc_tau, rho_r, rho_g, tau, xc_energy, &
                             xc_section_tmp, auxbas_pw_pool, &
                             compute_virial=.FALSE., virial_xc=virial_xc_tmp)

      ! Step 2: potential of a PZ81-only functional -> vxc_LB
      CALL section_vals_val_set(xc_fun_section_tmp, "XALPHA%_SECTION_PARAMETERS_", &
                                l_val=.FALSE.)
      CALL section_vals_val_set(xc_fun_section_tmp, "PZ81%_SECTION_PARAMETERS_", &
                                l_val=.TRUE.)

      CALL xc_vxc_pw_create1(vxc_LB, vxc_tau, rho_r, rho_g, tau, xc_energy, &
                             xc_section_tmp, auxbas_pw_pool, &
                             compute_virial=.FALSE., virial_xc=virial_xc_tmp)

      ! Step 3: vxc_LB <- vxc_LB + alpha*vxc_tmp
      DO ispin = 1, nspins
         CALL pw_axpy(vxc_tmp(ispin), vxc_LB(ispin), alpha)
      END DO

      ! Step 4: vxc_tmp is handed to add_lb_pot together with the density set and the
      !         result is subtracted from vxc_LB
      DO ispin = 1, nspins
         CALL add_lb_pot(vxc_tmp(ispin)%array, rho_set, lsd, ispin)
         CALL pw_axpy(vxc_tmp(ispin), vxc_LB(ispin), -1.0_dp)
      END DO

      !-----------------------------------------------------------------------------------!
      ! Construct 2 times PBE one particle density from the PZ correlation energy density !
      !-----------------------------------------------------------------------------------!
      CALL xc_dset_create(deriv_set, local_bounds=bo)
      ! deriv_order=0: only the energy density of the functional section is evaluated
      CALL xc_functionals_eval(xc_fun_section_tmp, &
                               lsd=lsd, &
                               rho_set=rho_set, &
                               deriv_set=deriv_set, &
                               deriv_order=0)

      deriv => xc_dset_get_derivative(deriv_set, [INTEGER::])
      CALL xc_derivative_get(deriv, deriv_data=e_uniform)

      ALLOCATE (vxc_GLLB(nspins))
      DO ispin = 1, nspins
         CALL auxbas_pw_pool%create_pw(vxc_GLLB(ispin))
      END DO

      DO ispin = 1, nspins
         CALL calc_2excpbe(vxc_GLLB(ispin)%array, rho_set, e_uniform, lsd)
      END DO

      ! e_uniform points into deriv_set and must not be used after this release
      CALL xc_dset_release(deriv_set)

      CALL auxbas_pw_pool%create_pw(orbital)
      CALL auxbas_pw_pool%create_pw(orbital_g)

      !-------------------------------------------------------------------!
      ! Orbital-dependent term of GLLB:                                   !
      !   sum_{orb<homo} K_rho*SQRT(e_homo - e_orb)*rho_orb / rho         !
      !-------------------------------------------------------------------!
      DO ispin = 1, nspins

         CALL get_mo_set(molecular_orbitals(ispin), &
                         mo_coeff=mo_coeff, &
                         eigenvalues=mo_eigenvalues, &
                         homo=homo)
         CALL cp_fm_create(single_mo_coeff(ispin), &
                           mo_coeff%matrix_struct, &
                           "orbital density matrix")

         CALL cp_fm_get_info(single_mo_coeff(ispin), &
                             nrow_global=nrow(ispin), ncol_global=ncol(ispin))

         ! This term only enters vxc_GLLB, which is not used for a pure LB correction
         IF (.NOT. need_gllb) CYCLE

         ALLOCATE (coeff_col(nrow(ispin), 1))

         ! vxc_tmp is reused as accumulator for the weighted orbital densities
         CALL pw_zero(vxc_tmp(ispin))

         DO orb = 1, homo - 1

            efac = K_rho*SQRT(mo_eigenvalues(homo) - mo_eigenvalues(orb))
            ! without LSD the weight is doubled
            IF (.NOT. lsd) efac = 2.0_dp*efac

            CALL collocate_orbital_density(ispin, orb)

            CALL pw_axpy(orbital, vxc_tmp(ispin), efac)

         END DO
         DEALLOCATE (coeff_col)

         ! Divide by the density; points at or below the density cutoff are set to zero
         DO k = bo(1, 3), bo(2, 3)
            DO j = bo(1, 2), bo(2, 2)
               DO i = bo(1, 1), bo(2, 1)
                  IF (rho_r(ispin)%array(i, j, k) > density_cut) THEN
                     vxc_tmp(ispin)%array(i, j, k) = vxc_tmp(ispin)%array(i, j, k)/ &
                                                     rho_r(ispin)%array(i, j, k)
                  ELSE
                     vxc_tmp(ispin)%array(i, j, k) = 0.0_dp
                  END IF
               END DO
            END DO
         END DO

         CALL pw_axpy(vxc_tmp(ispin), vxc_GLLB(ispin), 1.0_dp)

      END DO

      !---------------!
      ! Assemble SAOP !
      !---------------!
      ALLOCATE (vxc_SAOP(nspins))

      DO ispin = 1, nspins

         CALL auxbas_pw_pool%create_pw(vxc_SAOP(ispin))
         CALL pw_zero(vxc_SAOP(ispin))

         ! For oe_lb and oe_gllb vxc_SAOP is overwritten by pw_copy below, hence the
         ! per-orbital assembly is only carried out when SAOP itself is requested
         IF (need_saop) THEN

            CALL get_mo_set(molecular_orbitals(ispin), &
                            mo_coeff=mo_coeff, &
                            eigenvalues=mo_eigenvalues, &
                            homo=homo)

            ALLOCATE (coeff_col(nrow(ispin), 1))

            DO orb = 1, homo

               ! Orbitals close to the HOMO take the LB potential, deeper ones GLLB
               we_LB = EXP(-2.0_dp*(mo_eigenvalues(homo) - mo_eigenvalues(orb))**2)
               we_GLLB = 1.0_dp - we_LB
               IF (.NOT. lsd) THEN
                  we_LB = 2.0_dp*we_LB
                  we_GLLB = 2.0_dp*we_GLLB
               END IF

               vxc_tmp(ispin)%array = we_LB*vxc_LB(ispin)%array + &
                                      we_GLLB*vxc_GLLB(ispin)%array

               CALL collocate_orbital_density(ispin, orb)

               vxc_SAOP(ispin)%array = vxc_SAOP(ispin)%array + &
                                       orbital%array*vxc_tmp(ispin)%array

            END DO

            DEALLOCATE (coeff_col)

            ! Divide by the density; points at or below the density cutoff are set to zero
            DO k = bo(1, 3), bo(2, 3)
               DO j = bo(1, 2), bo(2, 2)
                  DO i = bo(1, 1), bo(2, 1)
                     IF (rho_r(ispin)%array(i, j, k) > density_cut) THEN
                        vxc_SAOP(ispin)%array(i, j, k) = vxc_SAOP(ispin)%array(i, j, k)/ &
                                                         rho_r(ispin)%array(i, j, k)
                     ELSE
                        vxc_SAOP(ispin)%array(i, j, k) = 0.0_dp
                     END IF
                  END DO
               END DO
            END DO

         END IF

         CALL dbcsr_deallocate_matrix(orbital_density_matrix(ispin)%matrix)

      END DO

      CALL cp_fm_release(single_mo_coeff)

      CALL xc_rho_set_release(rho_set, auxbas_pw_pool)
      CALL auxbas_pw_pool%give_back_pw(orbital)
      CALL auxbas_pw_pool%give_back_pw(orbital_g)

      !--------------------!
      ! Do the integration !
      !--------------------!
      DO ispin = 1, nspins

         ! vxc_SAOP doubles as the buffer that is integrated: replace its content when
         ! one of the two component potentials was requested instead of SAOP
         IF (oe_corr == oe_lb) THEN
            CALL pw_copy(vxc_LB(ispin), vxc_SAOP(ispin))
         ELSE IF (oe_corr == oe_gllb) THEN
            CALL pw_copy(vxc_GLLB(ispin), vxc_SAOP(ispin))
         END IF
         ! scale by the grid volume element before the integration
         CALL pw_scale(vxc_SAOP(ispin), vxc_SAOP(ispin)%pw_grid%dvol)

         CALL integrate_v_rspace(v_rspace=vxc_SAOP(ispin), pmat=rho_struct_ao(ispin), &
                                 hmat=ks_matrix(ispin), qs_env=qs_env, &
                                 calculate_forces=.FALSE., &
                                 gapw=gapw)

      END DO

      DO ispin = 1, nspins
         CALL auxbas_pw_pool%give_back_pw(vxc_SAOP(ispin))
         CALL auxbas_pw_pool%give_back_pw(vxc_GLLB(ispin))
         CALL vxc_LB(ispin)%release()
         CALL vxc_tmp(ispin)%release()
      END DO
      DEALLOCATE (vxc_GLLB, vxc_LB, vxc_tmp, orbital_density_matrix)

      DEALLOCATE (vxc_SAOP)

      CALL section_vals_release(xc_fun_section_tmp)
      CALL section_vals_release(xc_section_tmp)
      CALL section_vals_release(xc_section_orig)

      !-----------------------!
      ! Call the GAPW routine !
      !-----------------------!
      IF (gapw) THEN
         CALL gapw_add_atomic_saop_pot(qs_env, oe_corr)
      END IF

   CONTAINS

! **************************************************************************************************
!> \brief Collocates the density of one molecular orbital on the grid.
!>        Column iorb of the host variable mo_coeff is copied into the otherwise zeroed
!>        single_mo_coeff(jspin), the scratch matrix orbital_density_matrix(jspin) is reset
!>        and updated from it with cp_dbcsr_plus_fm_fm_t, and the result is passed to
!>        calculate_rho_elec, which fills the host variables orbital and orbital_g.
!> \param jspin spin index of the orbital
!> \param iorb index (column of mo_coeff) of the orbital
!> \note Works on the host variables by host association: mo_coeff must point to the MO
!>       coefficients of spin jspin and coeff_col must be allocated as (nrow(jspin), 1).
! **************************************************************************************************
      SUBROUTINE collocate_orbital_density(jspin, iorb)
         INTEGER, INTENT(IN)                                :: jspin, iorb

         CALL cp_fm_set_all(single_mo_coeff(jspin), 0.0_dp)
         CALL cp_fm_get_submatrix(mo_coeff, coeff_col, &
                                  1, iorb, nrow(jspin), 1)
         CALL cp_fm_set_submatrix(single_mo_coeff(jspin), coeff_col, &
                                  1, iorb)
         CALL dbcsr_set(orbital_density_matrix(jspin)%matrix, 0.0_dp)
         CALL cp_dbcsr_plus_fm_fm_t(orbital_density_matrix(jspin)%matrix, &
                                    matrix_v=single_mo_coeff(jspin), &
                                    ncol=ncol(jspin), &
                                    alpha=1.0_dp)
         CALL pw_zero(orbital)
         CALL pw_zero(orbital_g)
         CALL calculate_rho_elec(matrix_p=orbital_density_matrix(jspin)%matrix, &
                                 rho=orbital, rho_gspace=orbital_g, &
                                 ks_env=ks_env)

      END SUBROUTINE collocate_orbital_density

   END SUBROUTINE add_saop_pot

! **************************************************************************************************
!> \brief ...
!> \param qs_env ...
!> \param oe_corr ...
! **************************************************************************************************
   SUBROUTINE gapw_add_atomic_saop_pot(qs_env, oe_corr)

      TYPE(qs_environment_type), POINTER                 :: qs_env
      INTEGER, INTENT(IN)                                :: oe_corr

      INTEGER                                            :: ia, iat, iatom, ikind, ir, ispin, na, &
                                                            natom, nr, ns, nspins, orb
      INTEGER, DIMENSION(2)                              :: bo, homo, ncol, nrow
      INTEGER, DIMENSION(2, 3)                           :: bounds
      INTEGER, DIMENSION(:), POINTER                     :: atom_list
      LOGICAL                                            :: lsd, paw_atom
      REAL(dp), DIMENSION(:, :, :), POINTER              :: tau
      REAL(KIND=dp)                                      :: density_cut, efac, exc, gradient_cut, &
                                                            tau_cut, we_GLLB, we_LB
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: coeff_col
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: weight
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER :: dummy, e_uniform, rho_h, rho_s, vtau, &
         vxc_GLLB_h, vxc_GLLB_s, vxc_LB_h, vxc_LB_s, vxc_SAOP_h, vxc_SAOP_s, vxc_tmp_h, vxc_tmp_s
      REAL(KIND=dp), DIMENSION(:, :, :, :), POINTER      :: drho_h, drho_s, vxg
      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(cp_1d_r_p_type), DIMENSION(:), POINTER        :: mo_eigenvalues
      TYPE(cp_fm_p_type), ALLOCATABLE, DIMENSION(:)      :: mo_coeff
      TYPE(cp_fm_type), ALLOCATABLE, DIMENSION(:)        :: single_mo_coeff
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_ks, orbital_density_matrix, &
                                                            rho_struct_ao
      TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: ksmat, psmat
      TYPE(dft_control_type), POINTER                    :: dft_control
      TYPE(grid_atom_type), POINTER                      :: atomic_grid, grid_atom
      TYPE(gto_basis_set_type), POINTER                  :: orb_basis
      TYPE(harmonics_atom_type), POINTER                 :: harmonics
      TYPE(local_rho_type), POINTER                      :: local_rho_set
      TYPE(mo_set_type), DIMENSION(:), POINTER           :: molecular_orbitals
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab
      TYPE(oce_matrix_type), POINTER                     :: oce
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      TYPE(qs_rho_type), POINTER                         :: rho_structure
      TYPE(rho_atom_coeff), DIMENSION(:), POINTER        :: dr_h, dr_s, int_hh, int_ss, r_h, r_s
      TYPE(rho_atom_coeff), DIMENSION(:, :), POINTER     :: r_h_d, r_s_d
      TYPE(rho_atom_type), DIMENSION(:), POINTER         :: rho_atom_set
      TYPE(rho_atom_type), POINTER                       :: rho_atom
      TYPE(section_vals_type), POINTER                   :: input, xc_fun_section_orig, &
                                                            xc_fun_section_tmp, xc_section_orig, &
                                                            xc_section_tmp
      TYPE(xc_derivative_set_type)                       :: deriv_set
      TYPE(xc_derivative_type), POINTER                  :: deriv
      TYPE(xc_rho_cflags_type)                           :: needs, needs_orbs
      TYPE(xc_rho_set_type)                              :: orb_rho_set_h, orb_rho_set_s, rho_set_h, &
                                                            rho_set_s

      NULLIFY (weight, rho_h, rho_s, vxc_LB_h, vxc_LB_s, vxc_GLLB_h, vxc_GLLB_s, &
               vxc_tmp_h, vxc_tmp_s, vtau, dummy, e_uniform, drho_h, drho_s, vxg, atom_list, &
               atomic_kind_set, qs_kind_set, deriv, atomic_grid, rho_struct_ao, &
               harmonics, molecular_orbitals, rho_structure, r_h, r_s, dr_h, dr_s, &
               r_h_d, r_s_d, rho_atom_set, rho_atom, para_env, &
               mo_eigenvalues, local_rho_set, matrix_ks, &
               orbital_density_matrix, vxc_SAOP_h, vxc_SAOP_s)

      ! tau is needed for fill_rho_set, but should never be used
      NULLIFY (tau)
      NULLIFY (dft_control, oce, sab)

      CALL get_qs_env(qs_env, input=input, &
                      rho=rho_structure, &
                      mos=molecular_orbitals, &
                      atomic_kind_set=atomic_kind_set, &
                      qs_kind_set=qs_kind_set, &
                      rho_atom_set=rho_atom_set, &
                      matrix_ks=matrix_ks, &
                      dft_control=dft_control, &
                      para_env=para_env, &
                      oce=oce, sab_orb=sab)

      CALL qs_rho_get(rho_structure, rho_ao=rho_struct_ao)

      xc_section_orig => section_vals_get_subs_vals(input, "DFT%XC")
      CALL section_vals_retain(xc_section_orig)
      CALL section_vals_duplicate(xc_section_orig, xc_section_tmp)

      ! [SC] the following code can be traced back to SVN rev. 4296 (git:f97138b) that
      !      has removed the component 'nspins' from the derived type 'dft_control_type'.
      !      Is it worth to remove the code below in favour of 'dft_control%nspins'
      !      since its reintroduction? Note that in case of ROKS calculations,
      !      'lsd == .FALSE.' but 'dft_control%nspins == 2'.
      CALL section_vals_val_get(input, "DFT%LSD", l_val=lsd)
      IF (lsd) THEN
         nspins = 2
      ELSE
         nspins = 1
      END IF

      CALL section_vals_val_get(xc_section_orig, "DENSITY_CUTOFF", &
                                r_val=density_cut)
      CALL section_vals_val_get(xc_section_orig, "GRADIENT_CUTOFF", &
                                r_val=gradient_cut)
      CALL section_vals_val_get(xc_section_orig, "TAU_CUTOFF", &
                                r_val=tau_cut)

      ! remap pointer
      ns = SIZE(rho_struct_ao)
      psmat(1:ns, 1:1) => rho_struct_ao(1:ns)
      CALL calculate_rho_atom_coeff(qs_env, psmat, rho_atom_set, qs_kind_set, oce, sab, para_env)
      CALL prepare_gapw_den(qs_env)

      ALLOCATE (mo_coeff(nspins), single_mo_coeff(nspins), mo_eigenvalues(nspins))

      CALL dbcsr_allocate_matrix_set(orbital_density_matrix, nspins)

      DO ispin = 1, nspins
         CALL get_mo_set(molecular_orbitals(ispin), &
                         mo_coeff=mo_coeff(ispin)%matrix, &
                         eigenvalues=mo_eigenvalues(ispin)%array, &
                         homo=homo(ispin))
         CALL cp_fm_create(single_mo_coeff(ispin), &
                           mo_coeff(ispin)%matrix%matrix_struct, &
                           "orbital density matrix")
         CALL cp_fm_get_info(single_mo_coeff(ispin), &
                             nrow_global=nrow(ispin), ncol_global=ncol(ispin))
         ALLOCATE (orbital_density_matrix(ispin)%matrix)
         CALL dbcsr_copy(orbital_density_matrix(ispin)%matrix, &
                         rho_struct_ao(ispin)%matrix, &
                         "orbital density")
      END DO
      CALL local_rho_set_create(local_rho_set)
      CALL allocate_rho_atom_internals(local_rho_set%rho_atom_set, atomic_kind_set, &
                                       qs_kind_set, dft_control, para_env)

      DO ikind = 1, SIZE(atomic_kind_set)
         CALL get_atomic_kind(atomic_kind_set(ikind), atom_list=atom_list, natom=natom)

         CALL get_qs_kind(qs_kind_set(ikind), paw_atom=paw_atom, &
                          harmonics=harmonics, grid_atom=atomic_grid)
         IF (.NOT. paw_atom) CYCLE

         nr = atomic_grid%nr
         na = atomic_grid%ng_sphere
         bounds(1:2, 1:3) = 1
         bounds(2, 1) = na
         bounds(2, 2) = nr

         CALL xc_dset_create(deriv_set, local_bounds=bounds)

         CALL xc_rho_set_create(rho_set_h, bounds, density_cut, &
                                gradient_cut, tau_cut)
         CALL xc_rho_set_create(rho_set_s, bounds, density_cut, &
                                gradient_cut, tau_cut)
         CALL xc_rho_set_create(orb_rho_set_h, bounds, density_cut, &
                                gradient_cut, tau_cut)
         CALL xc_rho_set_create(orb_rho_set_s, bounds, density_cut, &
                                gradient_cut, tau_cut)

         CALL xc_rho_cflags_setall(needs, .FALSE.)
         IF (lsd) THEN
            CALL xb88_lsd_info(needs=needs)
            needs%norm_drho = .TRUE.
         ELSE
            CALL xb88_lda_info(needs=needs)
         END IF
         CALL xc_rho_set_atom_update(rho_set_h, needs, nspins, bounds)
         CALL xc_rho_set_atom_update(rho_set_s, needs, nspins, bounds)
         CALL xc_rho_cflags_setall(needs_orbs, .FALSE.)
         needs_orbs%rho = .TRUE.
         IF (lsd) needs_orbs%rho_spin = .TRUE.
         CALL xc_rho_set_atom_update(orb_rho_set_h, needs, nspins, bounds)
         CALL xc_rho_set_atom_update(orb_rho_set_s, needs, nspins, bounds)

         ALLOCATE (rho_h(1:na, 1:nr, 1:nspins), rho_s(1:na, 1:nr, 1:nspins))
         ALLOCATE (weight(1:na, 1:nr))
         ALLOCATE (vxc_LB_h(1:na, 1:nr, 1:nspins), vxc_LB_s(1:na, 1:nr, 1:nspins))
         ALLOCATE (vxc_GLLB_h(1:na, 1:nr, 1:nspins), vxc_GLLB_s(1:na, 1:nr, 1:nspins))
         ALLOCATE (vxc_tmp_h(1:na, 1:nr, 1:nspins), vxc_tmp_s(1:na, 1:nr, 1:nspins))
         ALLOCATE (vxc_SAOP_h(1:na, 1:nr, 1:nspins), vxc_SAOP_s(1:na, 1:nr, 1:nspins))
         ALLOCATE (drho_h(1:4, 1:na, 1:nr, 1:nspins), drho_s(1:4, 1:na, 1:nr, 1:nspins))

         ! Distribute the atoms of this kind
         bo = get_limit(natom, para_env%num_pe, para_env%mepos)

         DO iat = 1, natom !bo(1),bo(2)
            iatom = atom_list(iat)

            rho_atom => rho_atom_set(iatom)
            NULLIFY (r_h, r_s, dr_h, dr_s, r_h_d, r_s_d)
            CALL get_rho_atom(rho_atom=rho_atom, rho_rad_h=r_h, &
                              rho_rad_s=r_s, drho_rad_h=dr_h, &
                              drho_rad_s=dr_s, rho_rad_h_d=r_h_d, &
                              rho_rad_s_d=r_s_d)
            rho_h = 0.0_dp
            rho_s = 0.0_dp
            drho_h = 0.0_dp
            drho_s = 0.0_dp
            DO ir = 1, nr
               CALL calc_rho_angular(atomic_grid, harmonics, nspins, .TRUE., &
                                     ir, r_h, r_s, rho_h, rho_s, &
                                     dr_h, dr_s, r_h_d, r_s_d, drho_h, drho_s)
            END DO
            DO ir = 1, nr
               CALL fill_rho_set(rho_set_h, lsd, nspins, needs, rho_h, drho_h, tau, na, ir)
               CALL fill_rho_set(rho_set_s, lsd, nspins, needs, rho_s, drho_s, tau, na, ir)
            END DO
            DO ir = 1, nr
               DO ia = 1, na
                  weight(ia, ir) = atomic_grid%wr(ir)*atomic_grid%wa(ia)
               END DO
            END DO

            !-----------------------------!
            ! 1. Slater exchange for LB94 !
            !-----------------------------!
            xc_fun_section_orig => section_vals_get_subs_vals(xc_section_orig, &
                                                              "XC_FUNCTIONAL")
            CALL section_vals_create(xc_fun_section_tmp, xc_fun_section_orig%section)
            CALL section_vals_val_set(xc_fun_section_tmp, "_SECTION_PARAMETERS_", &
                                      i_val=xc_funct_no_shortcut)
            CALL section_vals_val_set(xc_fun_section_tmp, "XALPHA%_SECTION_PARAMETERS_", &
                                      l_val=.TRUE.)
            CALL section_vals_set_subs_vals(xc_section_tmp, "XC_FUNCTIONAL", &
                                            xc_fun_section_tmp)

            !---------------------!
            ! Both: hard and soft !
            !---------------------!
            CALL xc_dset_zero_all(deriv_set)
            CALL vxc_of_r_new(xc_fun_section_tmp, rho_set_h, deriv_set, 1, needs, &
                              weight, lsd, na, nr, exc, vxc_tmp_h, vxg, vtau)
            CALL xc_dset_zero_all(deriv_set)
            CALL vxc_of_r_new(xc_fun_section_tmp, rho_set_s, deriv_set, 1, needs, &
                              weight, lsd, na, nr, exc, vxc_tmp_s, vxg, vtau)

            !-------------------------------------------!
            ! 2. PZ correlation for LB94 and ec_uniform !
            !-------------------------------------------!
            CALL section_vals_val_set(xc_fun_section_tmp, "XALPHA%_SECTION_PARAMETERS_", &
                                      l_val=.FALSE.)
            CALL section_vals_val_set(xc_fun_section_tmp, "PZ81%_SECTION_PARAMETERS_", &
                                      l_val=.TRUE.)

            !------!
            ! Hard !
            !------!
            CALL xc_dset_zero_all(deriv_set)
            CALL vxc_of_r_new(xc_fun_section_tmp, rho_set_h, deriv_set, 1, needs, &
                              weight, lsd, na, nr, exc, vxc_LB_h, vxg, vtau)
            vxc_LB_h = vxc_LB_h + alpha*vxc_tmp_h
            DO ispin = 1, nspins
               dummy => vxc_tmp_h(:, :, ispin:ispin)
               CALL add_lb_pot(dummy, rho_set_h, lsd, ispin)
               vxc_LB_h(:, :, ispin) = vxc_LB_h(:, :, ispin) - weight(:, :)*vxc_tmp_h(:, :, ispin)
            END DO
            NULLIFY (dummy)

            vxc_GLLB_h = 0.0_dp
            deriv => xc_dset_get_derivative(deriv_set, [INTEGER::])
            CPASSERT(ASSOCIATED(deriv))
            CALL xc_derivative_get(deriv, deriv_data=e_uniform)
            DO ispin = 1, nspins
               dummy => vxc_GLLB_h(:, :, ispin:ispin)
               CALL calc_2excpbe(dummy, rho_set_h, e_uniform, lsd)
               vxc_GLLB_h(:, :, ispin) = vxc_GLLB_h(:, :, ispin)*weight(:, :)
            END DO
            NULLIFY (deriv, dummy, e_uniform)

            !------!
            ! Soft !
            !------!
            CALL xc_dset_zero_all(deriv_set)
            CALL vxc_of_r_new(xc_fun_section_tmp, rho_set_s, deriv_set, 1, needs, &
                              weight, lsd, na, nr, exc, vxc_LB_s, vxg, vtau)

            vxc_LB_s = vxc_LB_s + alpha*vxc_tmp_s
            DO ispin = 1, nspins
               dummy => vxc_tmp_s(:, :, ispin:ispin)
               CALL add_lb_pot(dummy, rho_set_s, lsd, ispin)
               vxc_LB_s(:, :, ispin) = vxc_LB_s(:, :, ispin) - weight(:, :)*vxc_tmp_s(:, :, ispin)
            END DO
            NULLIFY (dummy)

            vxc_GLLB_s = 0.0_dp
            deriv => xc_dset_get_derivative(deriv_set, [INTEGER::])
            CPASSERT(ASSOCIATED(deriv))
            CALL xc_derivative_get(deriv, deriv_data=e_uniform)
            DO ispin = 1, nspins
               dummy => vxc_GLLB_s(:, :, ispin:ispin)
               CALL calc_2excpbe(dummy, rho_set_s, e_uniform, lsd)
               vxc_GLLB_s(:, :, ispin) = vxc_GLLB_s(:, :, ispin)*weight(:, :)
            END DO
            NULLIFY (deriv, dummy, e_uniform)

            !------------------!
            ! Now the orbitals !
            !------------------!
            vxc_tmp_h = 0.0_dp; vxc_tmp_s = 0.0_dp

            DO ispin = 1, nspins

               DO orb = 1, homo(ispin) - 1

                  ALLOCATE (coeff_col(nrow(ispin), 1))

                  efac = K_rho*SQRT(mo_eigenvalues(ispin)%array(homo(ispin)) - &
                                    mo_eigenvalues(ispin)%array(orb))
                  IF (.NOT. lsd) efac = 2.0_dp*efac

                  CALL cp_fm_set_all(single_mo_coeff(ispin), 0.0_dp)
                  CALL cp_fm_get_submatrix(mo_coeff(ispin)%matrix, coeff_col, &
                                           1, orb, nrow(ispin), 1)
                  CALL cp_fm_set_submatrix(single_mo_coeff(ispin), coeff_col, &
                                           1, orb)
                  CALL dbcsr_set(orbital_density_matrix(ispin)%matrix, 0.0_dp)
                  CALL cp_dbcsr_plus_fm_fm_t(orbital_density_matrix(ispin)%matrix, &
                                             matrix_v=single_mo_coeff(ispin), &
                                             ncol=ncol(ispin), &
                                             alpha=1.0_dp)

                  DEALLOCATE (coeff_col)

                  ! This calculates the CPC and density on the grids for every atom even though
                  ! we need it only for iatom at the moment. It seems that to circumvent this,
                  ! the routines must be adapted to calculate just iatom
                  ! remap pointer
                  ns = SIZE(orbital_density_matrix)
                  psmat(1:ns, 1:1) => orbital_density_matrix(1:ns)
                  CALL calculate_rho_atom_coeff(qs_env, psmat, local_rho_set%rho_atom_set, qs_kind_set, oce, sab, para_env)
                  CALL prepare_gapw_den(qs_env, local_rho_set, .FALSE.)

                  rho_atom => local_rho_set%rho_atom_set(iatom)
                  NULLIFY (r_h, r_s, dr_h, dr_s, r_h_d, r_s_d)
                  CALL get_rho_atom(rho_atom=rho_atom, rho_rad_h=r_h, rho_rad_s=r_s)
                  rho_h = 0.0_dp
                  rho_s = 0.0_dp
                  drho_h = 0.0_dp
                  drho_s = 0.0_dp
                  DO ir = 1, nr
                     CALL calc_rho_angular(atomic_grid, harmonics, nspins, .FALSE., &
                                           ir, r_h, r_s, rho_h, rho_s, &
                                           dr_h, dr_s, r_h_d, r_s_d, drho_h, drho_s)
                  END DO
                  DO ir = 1, nr
                     CALL fill_rho_set(orb_rho_set_h, lsd, nspins, needs_orbs, rho_h, drho_h, tau, na, ir)
                     CALL fill_rho_set(orb_rho_set_s, lsd, nspins, needs_orbs, rho_s, drho_s, tau, na, ir)
                  END DO

                  IF (lsd) THEN
                     IF (ispin == 1) THEN
                        vxc_tmp_h(:, :, 1) = vxc_tmp_h(:, :, 1) + efac*orb_rho_set_h%rhoa(:, :, 1)
                        vxc_tmp_s(:, :, 1) = vxc_tmp_s(:, :, 1) + efac*orb_rho_set_s%rhoa(:, :, 1)
                     ELSE
                        vxc_tmp_h(:, :, 2) = vxc_tmp_h(:, :, 2) + efac*orb_rho_set_h%rhob(:, :, 1)
                        vxc_tmp_s(:, :, 2) = vxc_tmp_s(:, :, 2) + efac*orb_rho_set_s%rhob(:, :, 1)
                     END IF
                  ELSE
                     vxc_tmp_h(:, :, 1) = vxc_tmp_h(:, :, 1) + efac*orb_rho_set_h%rho(:, :, 1)
                     vxc_tmp_s(:, :, 1) = vxc_tmp_s(:, :, 1) + efac*orb_rho_set_s%rho(:, :, 1)
                  END IF

               END DO ! orb

            END DO ! ispin

            IF (lsd) THEN
               DO ir = 1, nr
                  DO ia = 1, na
                     IF (rho_set_h%rhoa(ia, ir, 1) > rho_set_h%rho_cutoff) &
                        vxc_GLLB_h(ia, ir, 1) = vxc_GLLB_h(ia, ir, 1) + &
                                                weight(ia, ir)*vxc_tmp_h(ia, ir, 1)/rho_set_h%rhoa(ia, ir, 1)
                     IF (rho_set_h%rhob(ia, ir, 1) > rho_set_h%rho_cutoff) &
                        vxc_GLLB_h(ia, ir, 2) = vxc_GLLB_h(ia, ir, 2) + &
                                                weight(ia, ir)*vxc_tmp_h(ia, ir, 2)/rho_set_h%rhob(ia, ir, 1)
                     IF (rho_set_s%rhoa(ia, ir, 1) > rho_set_s%rho_cutoff) &
                        vxc_GLLB_s(ia, ir, 1) = vxc_GLLB_s(ia, ir, 1) + &
                                                weight(ia, ir)*vxc_tmp_s(ia, ir, 1)/rho_set_s%rhoa(ia, ir, 1)
                     IF (rho_set_s%rhob(ia, ir, 1) > rho_set_s%rho_cutoff) &
                        vxc_GLLB_s(ia, ir, 2) = vxc_GLLB_s(ia, ir, 2) + &
                                                weight(ia, ir)*vxc_tmp_s(ia, ir, 2)/rho_set_s%rhob(ia, ir, 1)
                  END DO
               END DO
            ELSE
               DO ir = 1, nr
                  DO ia = 1, na
                     IF (rho_set_h%rho(ia, ir, 1) > rho_set_h%rho_cutoff) &
                        vxc_GLLB_h(ia, ir, 1) = vxc_GLLB_h(ia, ir, 1) + &
                                                weight(ia, ir)*vxc_tmp_h(ia, ir, 1)/rho_set_h%rho(ia, ir, 1)
                     IF (rho_set_s%rho(ia, ir, 1) > rho_set_s%rho_cutoff) &
                        vxc_GLLB_s(ia, ir, 1) = vxc_GLLB_s(ia, ir, 1) + &
                                                weight(ia, ir)*vxc_tmp_s(ia, ir, 1)/rho_set_s%rho(ia, ir, 1)
                  END DO
               END DO
            END IF

            vxc_SAOP_h = 0.0_dp; vxc_SAOP_s = 0.0_dp

            DO ispin = 1, nspins

               DO orb = 1, homo(ispin)

                  ALLOCATE (coeff_col(nrow(ispin), 1))

                  we_LB = EXP(-2.0_dp*(mo_eigenvalues(ispin)%array(homo(ispin)) - &
                                       mo_eigenvalues(ispin)%array(orb))**2)
                  we_GLLB = 1.0_dp - we_LB
                  IF (.NOT. lsd) THEN
                     we_LB = 2.0_dp*we_LB
                     we_GLLB = 2.0_dp*we_GLLB
                  END IF

                  vxc_tmp_h(:, :, ispin) = we_LB*vxc_LB_h(:, :, ispin) + &
                                           we_GLLB*vxc_GLLB_h(:, :, ispin)
                  vxc_tmp_s(:, :, ispin) = we_LB*vxc_LB_s(:, :, ispin) + &
                                           we_GLLB*vxc_GLLB_s(:, :, ispin)

                  CALL cp_fm_set_all(single_mo_coeff(ispin), 0.0_dp)
                  CALL cp_fm_get_submatrix(mo_coeff(ispin)%matrix, coeff_col, &
                                           1, orb, nrow(ispin), 1)
                  CALL cp_fm_set_submatrix(single_mo_coeff(ispin), coeff_col, &
                                           1, orb)
                  CALL dbcsr_set(orbital_density_matrix(ispin)%matrix, 0.0_dp)
                  CALL cp_dbcsr_plus_fm_fm_t(orbital_density_matrix(ispin)%matrix, &
                                             matrix_v=single_mo_coeff(ispin), &
                                             ncol=ncol(ispin), &
                                             alpha=1.0_dp)

                  DEALLOCATE (coeff_col)

                  ! This calculates the CPC and density on the grids for every atom even though
                  ! we need it only for iatom at the moment. It seems that to circumvent this,
                  ! the routines must be adapted to calculate just iatom
                  ! remap pointer
                  ns = SIZE(orbital_density_matrix)
                  psmat(1:ns, 1:1) => orbital_density_matrix(1:ns)
                  CALL calculate_rho_atom_coeff(qs_env, psmat, local_rho_set%rho_atom_set, qs_kind_set, oce, sab, para_env)
                  CALL prepare_gapw_den(qs_env, local_rho_set, .FALSE.)

                  rho_atom => local_rho_set%rho_atom_set(iatom)
                  NULLIFY (r_h, r_s, dr_h, dr_s, r_h_d, r_s_d)
                  CALL get_rho_atom(rho_atom=rho_atom, rho_rad_h=r_h, rho_rad_s=r_s)
                  rho_h = 0.0_dp
                  rho_s = 0.0_dp
                  drho_h = 0.0_dp
                  drho_s = 0.0_dp
                  DO ir = 1, nr
                     CALL calc_rho_angular(atomic_grid, harmonics, nspins, .FALSE., &
                                           ir, r_h, r_s, rho_h, rho_s, &
                                           dr_h, dr_s, r_h_d, r_s_d, drho_h, drho_s)
                  END DO
                  DO ir = 1, nr
                     CALL fill_rho_set(orb_rho_set_h, lsd, nspins, needs_orbs, rho_h, drho_h, tau, na, ir)
                     CALL fill_rho_set(orb_rho_set_s, lsd, nspins, needs_orbs, rho_s, drho_s, tau, na, ir)
                  END DO

                  IF (lsd) THEN
                     IF (ispin == 1) THEN
                        vxc_SAOP_h(:, :, 1) = vxc_SAOP_h(:, :, 1) + vxc_tmp_h(:, :, 1)*orb_rho_set_h%rhoa(:, :, 1)
                        vxc_SAOP_s(:, :, 1) = vxc_SAOP_s(:, :, 1) + vxc_tmp_s(:, :, 1)*orb_rho_set_s%rhoa(:, :, 1)
                     ELSE
                        vxc_SAOP_h(:, :, 2) = vxc_SAOP_h(:, :, 2) + vxc_tmp_h(:, :, 2)*orb_rho_set_h%rhob(:, :, 1)
                        vxc_SAOP_s(:, :, 2) = vxc_SAOP_s(:, :, 2) + vxc_tmp_s(:, :, 2)*orb_rho_set_s%rhob(:, :, 1)
                     END IF
                  ELSE
                     vxc_SAOP_h(:, :, 1) = vxc_SAOP_h(:, :, 1) + vxc_tmp_h(:, :, 1)*orb_rho_set_h%rho(:, :, 1)
                     vxc_SAOP_s(:, :, 1) = vxc_SAOP_s(:, :, 1) + vxc_tmp_s(:, :, 1)*orb_rho_set_s%rho(:, :, 1)
                  END IF

               END DO ! orb

            END DO ! ispin

            IF (lsd) THEN
               DO ir = 1, nr
                  DO ia = 1, na
                     IF (rho_set_h%rhoa(ia, ir, 1) > rho_set_h%rho_cutoff) THEN
                        vxc_SAOP_h(ia, ir, 1) = vxc_SAOP_h(ia, ir, 1)/rho_set_h%rhoa(ia, ir, 1)
                     ELSE
                        vxc_SAOP_h(ia, ir, 1) = 0.0_dp
                     END IF
                     IF (rho_set_h%rhob(ia, ir, 1) > rho_set_h%rho_cutoff) THEN
                        vxc_SAOP_h(ia, ir, 2) = vxc_SAOP_h(ia, ir, 2)/rho_set_h%rhob(ia, ir, 1)
                     ELSE
                        vxc_SAOP_h(ia, ir, 2) = 0.0_dp
                     END IF
                     IF (rho_set_s%rhoa(ia, ir, 1) > rho_set_s%rho_cutoff) THEN
                        vxc_SAOP_s(ia, ir, 1) = vxc_SAOP_s(ia, ir, 1)/rho_set_s%rhoa(ia, ir, 1)
                     ELSE
                        vxc_SAOP_s(ia, ir, 1) = 0.0_dp
                     END IF
                     IF (rho_set_s%rhob(ia, ir, 1) > rho_set_s%rho_cutoff) THEN
                        vxc_SAOP_s(ia, ir, 2) = vxc_SAOP_s(ia, ir, 2)/rho_set_s%rhob(ia, ir, 1)
                     ELSE
                        vxc_SAOP_s(ia, ir, 2) = 0.0_dp
                     END IF
                  END DO
               END DO
            ELSE
               DO ir = 1, nr
                  DO ia = 1, na
                     IF (rho_set_h%rho(ia, ir, 1) > rho_set_h%rho_cutoff) THEN
                        vxc_SAOP_h(ia, ir, 1) = vxc_SAOP_h(ia, ir, 1)/rho_set_h%rho(ia, ir, 1)
                     ELSE
                        vxc_SAOP_h(ia, ir, 1) = 0.0_dp
                     END IF
                     IF (rho_set_s%rho(ia, ir, 1) > rho_set_s%rho_cutoff) THEN
                        vxc_SAOP_s(ia, ir, 1) = vxc_SAOP_s(ia, ir, 1)/rho_set_s%rho(ia, ir, 1)
                     ELSE
                        vxc_SAOP_s(ia, ir, 1) = 0.0_dp
                     END IF
                  END DO
               END DO
            END IF

            rho_atom => rho_atom_set(iatom)
            CALL get_rho_atom(rho_atom=rho_atom, ga_Vlocal_gb_h=int_hh, ga_Vlocal_gb_s=int_ss)
            CALL get_qs_kind(qs_kind_set(ikind), basis_set=orb_basis, &
                             harmonics=harmonics, grid_atom=grid_atom)
            SELECT CASE (oe_corr)
            CASE (oe_lb)
               CALL gaVxcgb_noGC(vxc_LB_h, vxc_LB_s, int_hh, int_ss, grid_atom, orb_basis, harmonics, nspins)
            CASE (oe_gllb)
               CALL gaVxcgb_noGC(vxc_GLLB_h, vxc_GLLB_s, int_hh, int_ss, grid_atom, orb_basis, harmonics, nspins)
            CASE (oe_saop)
               CALL gaVxcgb_noGC(vxc_SAOP_h, vxc_SAOP_s, int_hh, int_ss, grid_atom, orb_basis, harmonics, nspins)
            CASE default
               CPABORT("Unknown correction!")
            END SELECT

         END DO

         DEALLOCATE (rho_h, rho_s, weight)
         DEALLOCATE (vxc_LB_h, vxc_LB_s)
         DEALLOCATE (vxc_GLLB_h, vxc_GLLB_s)
         DEALLOCATE (vxc_tmp_h, vxc_tmp_s)
         DEALLOCATE (vxc_SAOP_h, vxc_SAOP_s)
         DEALLOCATE (drho_h, drho_s)

         CALL xc_dset_release(deriv_set)
         CALL xc_rho_set_release(rho_set_h)
         CALL xc_rho_set_release(rho_set_s)
         CALL xc_rho_set_release(orb_rho_set_h)
         CALL xc_rho_set_release(orb_rho_set_s)

      END DO

      ! remap pointer
      ns = SIZE(matrix_ks)
      ksmat(1:ns, 1:1) => matrix_ks(1:ns)
      ns = SIZE(rho_struct_ao)
      psmat(1:ns, 1:1) => rho_struct_ao(1:ns)

      CALL update_ks_atom(qs_env, ksmat, psmat, forces=.FALSE.)

      !---------!
      ! Cleanup !
      !---------!
      CALL section_vals_release(xc_fun_section_tmp)
      CALL section_vals_release(xc_section_tmp)
      CALL section_vals_release(xc_section_orig)

      CALL local_rho_set_release(local_rho_set)
      CALL cp_fm_release(single_mo_coeff)
      DEALLOCATE (mo_coeff, mo_eigenvalues)
      CALL dbcsr_deallocate_matrix_set(orbital_density_matrix)

   END SUBROUTINE gapw_add_atomic_saop_pot

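! **************************************************************************************************
!> \brief Overwrites pot with beta*x^2*n^(1/3)/(1 + 3*beta*x*asinh(x)), x = |grad n|/n^(4/3),
!>        at every point of rho_set%local_bounds whose density exceeds rho_set%rho_cutoff;
!>        all other points of pot are left unchanged
!> \param pot array that receives the values
!> \param rho_set density set providing densities, gradient norms, bounds and the cutoff
!> \param lsd if true the spin density selected by spin is used, otherwise half of rho
!> \param spin selects rhoa (1) or rhob (2); only used if lsd is true
! **************************************************************************************************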
   SUBROUTINE add_lb_pot(pot, rho_set, lsd, spin)

      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: pot
      TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
      LOGICAL, INTENT(IN)                                :: lsd
      INTEGER, INTENT(IN)                                :: spin

      REAL(KIND=dp), PARAMETER                           :: one_third = 1.0_dp/3.0_dp

      INTEGER                                            :: i1, i2, i3
      INTEGER, DIMENSION(2, 3)                           :: lim
      REAL(KIND=dp)                                      :: cbrt_n, half_rho, xs, xs2

      ! For every grid point above the density cutoff, pot is overwritten with
      !    beta*x**2*n**(1/3)/(1 + 3*beta*x*asinh(x)),   x = |grad n|/n**(4/3)
      ! where n is half the total density (.NOT. lsd) or the density of the
      ! requested spin channel (lsd). Points below the cutoff keep their old value.

      lim = rho_set%local_bounds

      DO i3 = lim(1, 3), lim(2, 3)
         DO i2 = lim(1, 2), lim(2, 2)
            DO i1 = lim(1, 1), lim(2, 1)

               ! Step 1: pick n**(1/3) and the reduced gradient x for this point,
               !         skipping the point if its density is not above the cutoff
               IF (lsd) THEN
                  SELECT CASE (spin)
                  CASE (1)
                     IF (.NOT. (rho_set%rhoa(i1, i2, i3) > rho_set%rho_cutoff)) CYCLE
                     cbrt_n = rho_set%rhoa_1_3(i1, i2, i3)
                     xs = rho_set%norm_drhoa(i1, i2, i3)/(rho_set%rhoa(i1, i2, i3)*cbrt_n)
                  CASE (2)
                     IF (.NOT. (rho_set%rhob(i1, i2, i3) > rho_set%rho_cutoff)) CYCLE
                     cbrt_n = rho_set%rhob_1_3(i1, i2, i3)
                     xs = rho_set%norm_drhob(i1, i2, i3)/(rho_set%rhob(i1, i2, i3)*cbrt_n)
                  CASE DEFAULT
                     ! any other spin index leaves pot untouched
                     CYCLE
                  END SELECT
               ELSE
                  IF (.NOT. (rho_set%rho(i1, i2, i3) > rho_set%rho_cutoff)) CYCLE
                  half_rho = rho_set%rho(i1, i2, i3)/2.0_dp
                  cbrt_n = half_rho**one_third
                  xs = (rho_set%norm_drho(i1, i2, i3)/2.0_dp)/(half_rho*cbrt_n)
               END IF

               ! Step 2: common expression; LOG(x + SQRT(x**2 + 1)) is asinh(x)
               xs2 = xs*xs
               pot(i1, i2, i3) = beta*xs2*cbrt_n/ &
                                 (1.0_dp + 3.0_dp*beta*xs*LOG(xs + SQRT(xs2 + 1.0_dp)))

            END DO
         END DO
      END DO

   END SUBROUTINE add_lb_pot

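! **************************************************************************************************
!> \brief Fills pot, at every point of rho_set%local_bounds, with twice the value of
!>        calc_ecpbe_r/calc_ecpbe_u plus twice the value of calc_expbe_r/calc_expbe_u
!> \param pot array that receives the values
!> \param rho_set density set providing densities, gradient norm, bounds and cutoffs
!> \param e_uniform array that is divided by the density (where the density exceeds
!>        rho_set%rho_cutoff, zero is used otherwise) and passed on as ec_unif
!> \param lsd if true rhoa/rhob and the *_u functions are used, otherwise rho and the *_r functions
! **************************************************************************************************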
   SUBROUTINE calc_2excpbe(pot, rho_set, e_uniform, lsd)

      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: pot
      TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: e_uniform
      LOGICAL, INTENT(IN)                                :: lsd

      INTEGER                                            :: i1, i2, i3
      INTEGER, DIMENSION(2, 3)                           :: lim
      REAL(KIND=dp)                                      :: dens, dens_a, dens_b, drho_cut, ec_part, &
                                                            eps_unif, ex_part, grad, rho_cut

      ! pot = 2*(correlation-type term) + 2*(exchange-type term) on every grid point.
      ! e_uniform is converted to a per-density quantity before it is handed to the
      ! correlation routine; below the density cutoff zero is passed instead.

      lim = rho_set%local_bounds
      rho_cut = rho_set%rho_cutoff
      drho_cut = rho_set%drho_cutoff

      DO i3 = lim(1, 3), lim(2, 3)
         DO i2 = lim(1, 2), lim(2, 2)
            DO i1 = lim(1, 1), lim(2, 1)

               ! both variants use the gradient norm of the total density
               grad = rho_set%norm_drho(i1, i2, i3)

               IF (lsd) THEN
                  dens_a = rho_set%rhoa(i1, i2, i3)
                  dens_b = rho_set%rhob(i1, i2, i3)
                  dens = dens_a + dens_b
                  eps_unif = 0.0_dp
                  IF (dens > rho_cut) eps_unif = e_uniform(i1, i2, i3)/dens
                  ec_part = calc_ecpbe_u(dens_a, dens_b, grad, eps_unif, rho_cut, drho_cut)
                  ex_part = calc_expbe_u(dens_a, dens_b, grad, rho_cut, drho_cut)
               ELSE
                  dens = rho_set%rho(i1, i2, i3)
                  eps_unif = 0.0_dp
                  IF (dens > rho_cut) eps_unif = e_uniform(i1, i2, i3)/dens
                  ec_part = calc_ecpbe_r(dens, grad, eps_unif, rho_cut, drho_cut)
                  ex_part = calc_expbe_r(dens, grad, rho_cut, drho_cut)
               END IF

               pot(i1, i2, i3) = 2.0_dp*ec_part + 2.0_dp*ex_part

            END DO
         END DO
      END DO

   END SUBROUTINE calc_2excpbe

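! **************************************************************************************************
!> \brief Returns ec_unif + H for two spin densities, where
!>        H = gamma_saop*ln(1 + beta_ec/gamma_saop*t^2*(1 + A*t^2)/(1 + A*t^2 + (A*t^2)^2)),
!>        A = beta_ec/gamma_saop/(exp(-ec_unif/(gamma_saop*phi^3)) - 1),
!>        t = ngr/(2*phi*ks*r), r = ra + rb; H is zero unless r > rc and ngr > ngrc
!> \param ra first spin density
!> \param rb second spin density
!> \param ngr norm of the density gradient
!> \param ec_unif reference value to which H is added (calc_2excpbe passes e_uniform/density)
!> \param rc density cutoff
!> \param ngrc gradient cutoff
!> \return ec_unif + H
! **************************************************************************************************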
   FUNCTION calc_ecpbe_u(ra, rb, ngr, ec_unif, rc, ngrc) RESULT(res)

      REAL(kind=dp), INTENT(in)                          :: ra, rb, ngr, ec_unif, rc, ngrc
      REAL(kind=dp)                                      :: res

      REAL(kind=dp), PARAMETER                           :: one_third = 1.0_dp/3.0_dp, &
                                                            two_thirds = 2.0_dp/3.0_dp

      REAL(kind=dp)                                      :: a_coef, a_t2, beta_over_gamma, grad_corr, &
                                                            k_fermi, k_screen, numer, phi, phi_cubed, &
                                                            rho_tot, t_red, t_sq, zeta

      rho_tot = ra + rb

      ! The gradient correction is only evaluated above both cutoffs
      grad_corr = 0.0_dp
      IF (rho_tot > rc .AND. ngr > ngrc) THEN
         ! spin polarization, clamped to [-1, 1] against round-off
         zeta = (ra - rb)/rho_tot
         IF (zeta > 1.0_dp) THEN
            zeta = 1.0_dp
         ELSE IF (zeta < -1.0_dp) THEN
            zeta = -1.0_dp
         END IF
         phi = ((1.0_dp + zeta)**two_thirds + (1.0_dp - zeta)**two_thirds)/2.0_dp
         phi_cubed = phi*phi*phi

         ! wave vectors derived from the total density and the reduced gradient t
         k_fermi = (3.0_dp*rho_tot*pi*pi)**one_third
         k_screen = SQRT(4.0_dp*k_fermi/pi)
         t_red = ngr/(2.0_dp*phi*k_screen*rho_tot)
         t_sq = t_red**2

         beta_over_gamma = beta_ec/gamma_saop
         a_coef = beta_over_gamma/(EXP(-ec_unif/(gamma_saop*phi_cubed)) - 1.0_dp)
         a_t2 = a_coef*t_sq
         numer = 1.0_dp + a_t2
         grad_corr = gamma_saop*LOG(1.0_dp + beta_over_gamma*t_sq*(numer/(numer + a_t2*a_t2)))
      END IF

      res = ec_unif + grad_corr

   END FUNCTION calc_ecpbe_u

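! **************************************************************************************************
!> \brief Returns ec_unif + H for a single total density, where
!>        H = gamma_saop*ln(1 + beta_ec/gamma_saop*t^2*(1 + A*t^2)/(1 + A*t^2 + (A*t^2)^2)),
!>        A = beta_ec/gamma_saop/(exp(-ec_unif/gamma_saop) - 1), t = ngr/(2*ks*r);
!>        H is zero unless r > rc and ngr > ngrc
!> \param r density
!> \param ngr norm of the density gradient
!> \param ec_unif reference value to which H is added (calc_2excpbe passes e_uniform/density)
!> \param rc density cutoff
!> \param ngrc gradient cutoff
!> \return ec_unif + H
! **************************************************************************************************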
   FUNCTION calc_ecpbe_r(r, ngr, ec_unif, rc, ngrc) RESULT(res)

      REAL(kind=dp), INTENT(in)                          :: r, ngr, ec_unif, rc, ngrc
      REAL(kind=dp)                                      :: res

      REAL(kind=dp), PARAMETER                           :: one_third = 1.0_dp/3.0_dp

      REAL(kind=dp)                                      :: a_coef, a_t2, beta_over_gamma, grad_corr, &
                                                            k_fermi, k_screen, numer, t_red, t_sq

      ! The gradient correction is only evaluated above both cutoffs
      grad_corr = 0.0_dp
      IF (r > rc .AND. ngr > ngrc) THEN
         ! wave vectors derived from the density and the reduced gradient t
         k_fermi = (3.0_dp*r*pi*pi)**one_third
         k_screen = SQRT(4.0_dp*k_fermi/pi)
         t_red = ngr/(2.0_dp*k_screen*r)
         t_sq = t_red**2

         beta_over_gamma = beta_ec/gamma_saop
         a_coef = beta_over_gamma/(EXP(-ec_unif/gamma_saop) - 1.0_dp)
         a_t2 = a_coef*t_sq
         numer = 1.0_dp + a_t2
         grad_corr = gamma_saop*LOG(1.0_dp + beta_over_gamma*t_sq*(numer/(numer + a_t2*a_t2)))
      END IF

      res = ec_unif + grad_corr

   END FUNCTION calc_ecpbe_r

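! **************************************************************************************************
!> \brief Evaluates calc_expbe_r for the summed density ra + rb
!> \param ra first spin density
!> \param rb second spin density
!> \param ngr norm of the density gradient
!> \param rc density cutoff
!> \param ngrc gradient cutoff
!> \return calc_expbe_r(ra + rb, ngr, rc, ngrc)
! **************************************************************************************************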
   FUNCTION calc_expbe_u(ra, rb, ngr, rc, ngrc) RESULT(res)

      REAL(kind=dp), INTENT(in)                          :: ra, rb, ngr, rc, ngrc
      REAL(kind=dp)                                      :: res

      ! The two-density variant evaluates the single-density expression
      ! for the summed density ra + rb.
      res = calc_expbe_r(ra + rb, ngr, rc, ngrc)

   END FUNCTION calc_expbe_u

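! **************************************************************************************************
!> \brief Returns ex_unif*Fx with ex_unif = -3*kf/(4*pi), kf = (3*pi^2*r)^(1/3) and
!>        Fx = 1 + kappa - kappa/(1 + mu*s^2/kappa), s = ngr/(2*kf*r);
!>        Fx is 1 unless ngr > ngrc, and the result is zero unless r > rc
!> \param r density
!> \param ngr norm of the density gradient
!> \param rc density cutoff
!> \param ngrc gradient cutoff
!> \return ex_unif*Fx, or zero if the density is not above the cutoff
! **************************************************************************************************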
   FUNCTION calc_expbe_r(r, ngr, rc, ngrc) RESULT(res)

      REAL(kind=dp), INTENT(in)                          :: r, ngr, rc, ngrc
      REAL(kind=dp)                                      :: res

      REAL(kind=dp), PARAMETER                           :: one_third = 1.0_dp/3.0_dp

      REAL(kind=dp)                                      :: enhancement, ex_unif, k_fermi, s_red

      ! Zero unless the density is above the cutoff
      res = 0.0_dp
      IF (r > rc) THEN
         k_fermi = (3.0_dp*r*pi*pi)**one_third
         ex_unif = -3.0_dp*k_fermi/(4.0_dp*pi)

         ! The enhancement factor stays at one unless the gradient is above its cutoff
         enhancement = 1.0_dp
         IF (ngr > ngrc) THEN
            s_red = ngr/(2.0_dp*k_fermi*r)
            enhancement = 1.0_dp + kappa - kappa/(1.0_dp + mu*s_red*s_red/kappa)
         END IF

         res = ex_unif*enhancement
      END IF

   END FUNCTION calc_expbe_r

END MODULE xc_pot_saop
